import random
import torch

from DataHandling.FileBasedDatasetBase import FileBasedDataset
from Utils import logger, config as cfg
from Utils.utils import get_model_id_from_file_name
from Utils.Constants import Diff as DiffConst


class SelectionAgentDataloader:
    """
    Dynamic batch iterator used by the selection agent to fetch per-model data for predicting model results.

    !!!! This is an ugly work-around implementation of a dataloader so that it can be used in my way. !!!!
    The set of served models is dynamic: call ``remove_files`` every time the agent removes models, so that only
    data about the still-relevant models is returned.

    Differences from a normal torch DataLoader:
        * ``len`` - the number of models currently available (not the number of batches produced by iterating).
        * ``__next__`` - returns a whole batch ``(x, y)`` of tensors instead of a single data point.
        * ``__iter__`` returns ``self`` and does NOT rewind. Once a pass is exhausted, further ``next`` calls keep
          raising ``StopIteration`` (each such call also grows the sequence size, up to the cap) until ``reset`` or
          ``remove_files`` is called.
        * Every time a pass is exhausted, the number of leading CNN steps returned per model grows by
          ``sequence_size_increase`` (capped at ``DiffConst.NUMBER_STEPS_SAVED``).
    """

    def __init__(self, maps_dataset: FileBasedDataset, initial_num_steps_to_use: int, sequence_size_increase: int,
                 batch_size: int):
        """
        :param maps_dataset: Dataset to read from. Items are fetched as ``maps_dataset[index] -> (x, y)`` and the
                    model ids are read from ``maps_dataset.models_ids`` (position in that sequence == dataset index).
        :param initial_num_steps_to_use: Size of the first sequence used by the agent, i.e. how many leading CNN
                    steps of each model's ``x`` are returned during the first pass (minimal number of CNN steps used
                    by the selection agent for eval).
        :param sequence_size_increase: How many more CNN steps are used by the agent after every completed pass
                    (cut). E.g. with ``initial_num_steps_to_use=1`` and ``sequence_size_increase=4``, the first pass
                    uses the first CNN step and the second pass uses the first 5 (=1+4) CNN steps.
        :param batch_size: Maximal number of models per returned batch.
        """
        self._dataset = maps_dataset
        # Model id -> index in ``maps_dataset``; entries are removed as the agent drops models.
        self._models_id_to_index = {curr_id: idx for idx, curr_id in enumerate(self._dataset.models_ids)}
        self._sequence_size = initial_num_steps_to_use
        self._batch_size = batch_size
        self.sequence_size_increase = sequence_size_increase
        self._update_data()

    def _update_data(self):
        """Rebuild the cached id list and size from ``_models_id_to_index`` and rewind to the first batch."""
        self._size = len(self._models_id_to_index)
        self._last_index = 0
        # Snapshot of the remaining ids in the dataset's original order (dict insertion order).
        self._models_id_keys = list(self._models_id_to_index.keys())

    def __len__(self):
        """Number of models currently served (not the number of batches)."""
        return self._size

    @property
    def current_models_ids(self):
        """A copy of the ids of the models currently served, in iteration order."""
        return self._models_id_keys.copy()

    @property
    def original_models_results(self):
        """``(models_ids, models_results)`` of the underlying dataset, unaffected by ``remove_files``."""
        return self._dataset.models_ids, self._dataset.models_results

    def reset(self):
        """Rewind to the first batch so the loader can be iterated again (the sequence size is kept)."""
        self._update_data()

    def remove_files(self, removed_files=None, removed_ids=None):
        """
        Stop serving the given models and rewind to the first batch.

        :param removed_files: File names whose model ids are extracted with ``get_model_id_from_file_name``.
                    Takes precedence over ``removed_ids`` when given.
        :param removed_ids: Model ids to remove (used only when ``removed_files`` is None).
        :raises ValueError: If neither ``removed_files`` nor ``removed_ids`` is given.
        :raises KeyError: If an id is not currently served; nothing is removed in that case.
        """
        if removed_files is not None:
            removed_ids = [get_model_id_from_file_name(file) for file in removed_files]
        elif removed_ids is None:
            raise ValueError('SelectionAgentDataloader::remove_files: either removed_files or removed_ids must be given')
        else:
            # Materialize so a generator can be validated, consumed and logged.
            removed_ids = list(removed_ids)

        # Validate before mutating so a bad id cannot leave the loader half-updated (stale ``_models_id_keys``).
        unknown_ids = [curr_id for curr_id in removed_ids if curr_id not in self._models_id_to_index]
        if unknown_ids:
            raise KeyError(f'SelectionAgentDataloader::remove_files: unknown model ids: {unknown_ids}')

        for curr_id in removed_ids:
            # Default tolerates duplicate ids in the request (already validated above).
            self._models_id_to_index.pop(curr_id, None)
        self._update_data()
        logger().log('SelectionAgentDataloader::remove_files', 'Removing ids: ', removed_ids, ' New size: ', len(self))

    def get_batch(self):
        """
        Return the next batch ``(x, y)`` and advance the cursor by ``batch_size``.

        ``x`` stacks, per model, the first ``_sequence_size`` rows of the dataset item's ``x`` (sliced as
        ``x[:seq, :]``, so each ``x`` must be at least 2-D); ``y`` stacks the items' ``y`` values.
        Both tensors are placed on ``cfg.device``.
        """
        models_ids_in_batch = self._models_id_keys[self._last_index: self._last_index + self._batch_size]
        x_vals = list()
        y_vals = list()
        for curr_key in models_ids_in_batch:
            curr_x, curr_y = self._dataset[self._models_id_to_index[curr_key]]
            x_vals.append(curr_x[:self._sequence_size, :])
            y_vals.append(curr_y)

        self._last_index += self._batch_size
        return self._stack(x_vals), self._stack(y_vals)

    @staticmethod
    def _stack(values):
        """Stack equally-shaped array-likes / tensors / scalars into one tensor on ``cfg.device``."""
        if not values:
            # torch.stack rejects an empty list; keep the previous empty-tensor result.
            return torch.tensor(values, device=cfg.device)
        # torch.tensor() on a list of multi-element tensors raises, and on a list of numpy arrays it is very slow;
        # converting each item and stacking handles both.
        return torch.stack([torch.as_tensor(value) for value in values]).to(cfg.device)

    def __iter__(self):
        # Returns self without rewinding; call ``reset`` (or ``remove_files``) to start a new pass.
        return self

    def __next__(self):
        """
        Return the next batch. At the end of a pass, grow the sequence size (capped) and raise StopIteration.
        """
        if self._last_index < len(self):
            return self.get_batch()

        new_sequence_size = self._sequence_size + self.sequence_size_increase
        if new_sequence_size >= DiffConst.NUMBER_STEPS_SAVED:
            logger().warning(f'Got to maximal number of data step sequence size will be: {DiffConst.NUMBER_STEPS_SAVED}')
            # Cap at the number of saved steps, as the warning announces, instead of growing without bound.
            new_sequence_size = DiffConst.NUMBER_STEPS_SAVED
        self._sequence_size = new_sequence_size
        raise StopIteration
